{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
      "Loading checkpoint shards: 100%|██████████| 8/8 [00:06<00:00,  1.23it/s]\n"
     ]
    }
   ],
   "source": [
    "from modeling_chatglm import ChatGLMForConditionalGeneration\n",
    "import torch\n",
    "\n",
    "\n",
    "# Make CUDA half precision the default tensor type before the model is built\n",
    "# (presumably so the weights are created in fp16 on the GPU -- TODO confirm).\n",
    "torch.set_default_tensor_type(torch.cuda.HalfTensor)\n",
    "\n",
    "# `trust_remote_code` is intentionally not passed: the model class is imported from the\n",
    "# local `modeling_chatglm` module, and transformers reports that the flag is only\n",
    "# meaningful for Auto classes and is ignored here.\n",
    "model = ChatGLMForConditionalGeneration.from_pretrained(\n",
    "    \"THUDM/chatglm-6b\",\n",
    "    device_map='auto',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "# The tokenizer is resolved through an Auto class with `trust_remote_code=True`, i.e. it\n",
    "# is loaded with custom code. NOTE(review): transformers recommends passing an explicit\n",
    "# `revision` in this situation; none is pinned here.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    \"THUDM/chatglm-6b\",\n",
    "    trust_remote_code=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mymusise/pro/ChatGLM-Tuning/peft/tuners/lora.py:175: UserWarning: fan_in_fan_out is set to True but the target module is not a Conv1D. Setting fan_in_fan_out to False.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "\n",
    "from peft import get_peft_model, LoraConfig, TaskType\n",
    "\n",
    "peft_path = \"output/adapter_model.bin\"\n",
    "\n",
    "# LoRA settings for the adapter wrapper; presumably these must mirror the values used\n",
    "# when the adapter was trained -- TODO confirm against the training script.\n",
    "peft_config = LoraConfig(\n",
    "    task_type=TaskType.CAUSAL_LM, inference_mode=True,\n",
    "    r=8,\n",
    "    lora_alpha=32, lora_dropout=0.1\n",
    ")\n",
    "\n",
    "# NOTE: this rebinds `model` to the wrapped model, so re-running this cell on its own\n",
    "# wraps the already-wrapped model again; re-run the model-loading cell first.\n",
    "model = get_peft_model(model, peft_config)\n",
    "\n",
    "# NOTE: torch.load unpickles the file -- only load adapter checkpoints you trust.\n",
    "load_result = model.load_state_dict(torch.load(peft_path), strict=False)\n",
    "\n",
    "# strict=False tolerates keys that do not line up, which also hides naming mismatches.\n",
    "# Checkpoint keys that matched no parameter were NOT applied, so report them instead of\n",
    "# silently continuing with (partly) untrained adapter weights.\n",
    "if load_result.unexpected_keys:\n",
    "    warnings.warn(\n",
    "        f\"{len(load_result.unexpected_keys)} key(s) in {peft_path} did not match any model \"\n",
    "        f\"parameter and were ignored, e.g. {load_result.unexpected_keys[:3]}\"\n",
    "    )\n",
    "\n",
    "# From here on, newly created tensors default to CUDA float32 instead of CUDA half.\n",
    "torch.set_default_tensor_type(torch.cuda.FloatTensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Read the instruction records. The `with` block closes the file handle as soon as\n",
    "# parsing is done, and the explicit encoding keeps the result independent of the\n",
    "# platform's default locale.\n",
    "with open(\"data/alpaca_data.json\", encoding=\"utf-8\") as instructions_file:\n",
    "    instructions = json.load(instructions_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mymusise/pro/ChatGLM-Tuning/transformers/generation/utils.py:1374: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Instruction: Give three tips for staying healthy.\n",
      "Answer: 1. Eat a balanced diet.\n",
      "2. Get regular exercise.\n",
      "3. Stay hydrated.\n",
      "### 1.Answer:\n",
      " 1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n",
      "2. Exercise regularly to keep your body active and strong. \n",
      "3. Get enough sleep and maintain a consistent sleep schedule. \n",
      "\n",
      "\n",
      "Instruction: What are the three primary colors?\n",
      "Answer: The three primary colors are red, blue, and yellow.\n",
      "### 2.Answer:\n",
      " The three primary colors are red, blue, and yellow. \n",
      "\n",
      "\n",
      "Instruction: Describe the structure of an atom.\n",
      "Answer: An atom is a small particle of an element, containing a small number of particles of the element. Each particle is made up of a single electron, which is surrounded by a cloud of negative charge. The cloud of negative charge is surrounded by a cloud of positive charge, which creates a neutral cloud. The neutral cloud is made up of a cloud of negative charge, which creates a cloud of positive charge. This process continues until all the atoms have been neutralized.\n",
      "### 3.Answer:\n",
      " An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom. \n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from cover_alpaca2jsonl import format_example\n",
    "\n",
    "# Collected results: one dict per processed instruction (original fields plus\n",
    "# 'index' and 'infer_answer').\n",
    "answers = []\n",
    "\n",
    "# Build the prompt tensors on the device that holds the model's parameters; creating\n",
    "# them on the CPU makes generate() warn about a device mismatch with the CUDA model.\n",
    "# Assumes the first parameter sits on the device that receives the inputs -- TODO\n",
    "# confirm if the model is sharded across several GPUs.\n",
    "input_device = next(model.parameters()).device\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Only the first three instructions are run as a quick sanity check.\n",
    "    for idx, item in enumerate(instructions[:3]):\n",
    "        feature = format_example(item)\n",
    "        input_text = feature['context']\n",
    "        ids = tokenizer.encode(input_text)\n",
    "        input_ids = torch.tensor([ids], dtype=torch.long, device=input_device)\n",
    "        # Greedy decoding (do_sample=False); no `temperature` is passed because it only\n",
    "        # applies when sampling. `max_length` counts the prompt tokens as well.\n",
    "        out = model.generate(\n",
    "            input_ids=input_ids,\n",
    "            max_length=150,\n",
    "            do_sample=False\n",
    "        )\n",
    "        out_text = tokenizer.decode(out[0])\n",
    "        # The decoded text contains the prompt followed by the completion; strip the\n",
    "        # prompt and the \"\\nEND\" marker to keep only the generated answer.\n",
    "        answer = out_text.replace(input_text, \"\").replace(\"\\nEND\", \"\").strip()\n",
    "        # Stored on the instruction record itself, so `instructions` is updated in place.\n",
    "        item['infer_answer'] = answer\n",
    "        # Show the full model output, then the reference answer from the dataset.\n",
    "        print(out_text)\n",
    "        print(f\"### {idx+1}.Answer:\\n\", item.get('output'), '\\n\\n')\n",
    "        answers.append({'index': idx, **item})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "25273a2a68c96ebac13d7fb9e0db516f9be0772777a0507fe06d682a441a3ba7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
